{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "0eea5b13-8d0c-457f-b47d-28ee01980852",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "import os\n",
    "import cv2\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "19cadfe8-3461-4273-a34a-981e0be02cb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of output classes produced by the new classification head.\n",
    "NUM_CLASSES = 76\n",
    "\n",
    "# ImageNet-pretrained MobileNetV2; `weights=` replaces the deprecated `pretrained=True`.\n",
    "models = torchvision.models.mobilenet_v2(\n",
    "    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1\n",
    ")\n",
    "num_ftrs = models.classifier[1].in_features\n",
    "# The head must emit raw logits: nn.CrossEntropyLoss applies log-softmax itself. A trailing\n",
    "# Sigmoid squeezes every logit into (0, 1), so the softmax stays almost uniform over the\n",
    "# classes and the loss stays stuck around ln(76) ~ 4.33.\n",
    "# Kept as a one-layer Sequential so the head's parameter names (classifier.1.0.*) are unchanged.\n",
    "models.classifier[1] = nn.Sequential(\n",
    "    nn.Linear(num_ftrs, NUM_CLASSES, bias=True),\n",
    ")\n",
    "# Fall back to the CPU when no GPU is present instead of failing on a hard-coded 'cuda'.\n",
    "models = models.to('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "9a99081e-7178-4c43-8cea-d39a3506d656",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[237, 240, 242, ..., 255, 255, 255],\n",
       "        [238, 240, 242, ..., 255, 255, 255],\n",
       "        [238, 240, 243, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [234, 234, 234, ..., 255, 255, 255],\n",
       "        [233, 233, 233, ..., 255, 255, 255],\n",
       "        [237, 234, 224, ..., 254, 254, 254]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[221, 224, 227, ..., 248, 250, 252],\n",
       "        [223, 224, 227, ..., 249, 251, 252],\n",
       "        [225, 225, 226, ..., 251, 251, 251],\n",
       "        ...,\n",
       "        [220, 215, 217, ..., 254, 255, 255],\n",
       "        [222, 215, 215, ..., 255, 255, 255],\n",
       "        [210, 217, 220, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[214, 219, 215, ..., 253, 253, 253],\n",
       "        [215, 219, 215, ..., 253, 253, 253],\n",
       "        [215, 220, 215, ..., 253, 253, 253],\n",
       "        ...,\n",
       "        [206, 212, 208, ..., 238, 236, 235],\n",
       "        [211, 213, 206, ..., 239, 236, 234],\n",
       "        [208, 212, 211, ..., 235, 235, 240]], dtype=uint8),\n",
       " array([[254, 254, 254, ..., 254, 254, 254],\n",
       "        [254, 254, 254, ..., 254, 254, 254],\n",
       "        [254, 254, 254, ..., 254, 254, 254],\n",
       "        ...,\n",
       "        [229, 238, 239, ..., 219, 226, 224],\n",
       "        [228, 238, 239, ..., 219, 226, 220],\n",
       "        [232, 241, 241, ..., 221, 225, 233]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [231, 235, 234, ..., 222, 229, 227],\n",
       "        [230, 231, 229, ..., 223, 231, 228],\n",
       "        [235, 232, 228, ..., 221, 235, 231]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [230, 231, 232, ..., 231, 231, 235],\n",
       "        [229, 231, 233, ..., 229, 230, 235],\n",
       "        [227, 231, 231, ..., 244, 239, 242]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [236, 239, 238, ..., 246, 245, 247],\n",
       "        [235, 239, 237, ..., 246, 245, 250],\n",
       "        [239, 245, 241, ..., 249, 248, 247]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 250]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 253, 252, ..., 255, 255, 255],\n",
       "        [254, 249, 248, ..., 255, 255, 255],\n",
       "        [254, 249, 248, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 253, 254, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 253, 253, 253],\n",
       "        [255, 255, 255, ..., 253, 253, 253],\n",
       "        [255, 255, 255, ..., 253, 253, 253],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 250, 245, 240],\n",
       "        [255, 255, 255, ..., 251, 243, 235],\n",
       "        [255, 255, 255, ..., 240, 240, 245]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [254, 254, 254, ..., 255, 255, 255],\n",
       "        [254, 254, 254, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [252, 253, 254, ..., 254, 254, 255],\n",
       "        [253, 254, 254, ..., 254, 254, 255],\n",
       "        [250, 252, 244, ..., 254, 254, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [250, 253, 252, ..., 240, 242, 249],\n",
       "        [250, 254, 252, ..., 244, 245, 251],\n",
       "        [248, 250, 249, ..., 242, 245, 243]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        [255, 255, 255, ..., 254, 254, 254],\n",
       "        ...,\n",
       "        [247, 247, 247, ..., 247, 250, 253],\n",
       "        [247, 246, 246, ..., 245, 247, 251],\n",
       "        [246, 244, 244, ..., 251, 247, 247]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 251, 253, 254],\n",
       "        [255, 255, 255, ..., 251, 252, 252],\n",
       "        [255, 255, 255, ..., 250, 250, 250],\n",
       "        ...,\n",
       "        [251, 251, 251, ..., 242, 242, 244],\n",
       "        [251, 251, 250, ..., 240, 242, 245],\n",
       "        [254, 253, 252, ..., 245, 243, 238]], dtype=uint8),\n",
       " array([[242, 239, 242, ..., 255, 255, 255],\n",
       "        [242, 240, 241, ..., 255, 255, 255],\n",
       "        [244, 240, 240, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [236, 234, 224, ..., 255, 255, 255],\n",
       "        [240, 236, 224, ..., 255, 255, 255],\n",
       "        [230, 231, 230, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[244, 239, 237, ..., 255, 255, 255],\n",
       "        [241, 237, 236, ..., 255, 255, 255],\n",
       "        [249, 244, 242, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [233, 237, 237, ..., 255, 255, 255],\n",
       "        [233, 238, 238, ..., 255, 255, 255],\n",
       "        [237, 236, 236, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[227, 219, 215, ..., 255, 255, 255],\n",
       "        [227, 221, 221, ..., 255, 255, 255],\n",
       "        [223, 219, 221, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [218, 215, 214, ..., 255, 255, 255],\n",
       "        [219, 217, 215, ..., 255, 255, 255],\n",
       "        [213, 215, 217, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 247, 243, ..., 255, 255, 255],\n",
       "        [255, 246, 241, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8),\n",
       " array([[255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        ...,\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255],\n",
       "        [255, 255, 255, ..., 255, 255, 255]], dtype=uint8)]"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): `imgs` is not assigned in any cell visible here - confirm where it is defined,\n",
    "# otherwise this cell raises NameError on a fresh kernel.\n",
    "imgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "39bd6880-e4c2-41f3-8f8b-61bc37475d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory of the training images, relative to the notebook's working directory.\n",
    "# The last path component is Chinese for 'training set'.\n",
    "train_path = './datasets/Oracle/4_Recognize/训练集/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "490d7edc-29ef-4eb4-875b-9c42c429f44c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory of the test images, relative to the notebook's working directory.\n",
    "# The last path component is Chinese for 'test set'.\n",
    "test_path = './datasets/Oracle/4_Recognize/测试集/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "ac18e83d-7a7c-43bb-b494-854e2aa24af3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.listdir returns names in arbitrary order; sort them so that test_imgs[i] always\n",
    "# corresponds to test_files[i] and predictions can be mapped back to file names.\n",
    "test_files = sorted(os.listdir(test_path))\n",
    "test_imgs = [plt.imread(os.path.join(test_path, name)) for name in test_files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d8e91379-0fee-49e8-af5b-c31767ea97b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToTensor runs first so the pipeline accepts both PIL images (ImageFolder) and ndarrays\n",
    "# (plt.imread); Resize then operates on the resulting tensor.\n",
    "# The previous second ToTensor after Resize raised TypeError, because ToTensor only accepts\n",
    "# PIL images or ndarrays, not tensors.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Resize((256, 256)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "481ed4f7-7824-404f-a067-7ced5ee2d7ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training images read from `train_path`; every sample is passed through `transform`.\n",
    "train_dataset = datasets.ImageFolder(\n",
    "    root=train_path,\n",
    "    transform=transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "88f0557d-2d2e-4705-ab79-b0428bcf01a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled mini-batches of 8 training samples, loaded with 4 worker processes.\n",
    "dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=8,\n",
    "    shuffle=True,\n",
    "    num_workers=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "ad140e90-65d8-4e4b-b4aa-a416901c1bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(data, num_epochs, loss_func, model, device=None, lr=1e-5):\n",
    "    \"\"\"Train `model` in place with AdamW and print the mean loss of every epoch.\n",
    "\n",
    "    Args:\n",
    "        data: iterable yielding `(imgs, labels)` batches, e.g. a DataLoader.\n",
    "        num_epochs: number of full passes over `data`.\n",
    "        loss_func: callable `loss_func(out, labels)` returning a scalar tensor.\n",
    "        model: the nn.Module to optimise; it is moved to `device`.\n",
    "        device: target device; defaults to 'cuda' when available, else 'cpu'.\n",
    "        lr: AdamW learning rate (default 1e-5, the value previously hard-coded).\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    model.to(device)\n",
    "    # Optimise the model that was passed in, not the notebook-level `models` global.\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "    # Make sure dropout / batch-norm are in training mode even if the caller left it in eval.\n",
    "    model.train()\n",
    "    for epoch in range(num_epochs):\n",
    "        progress_bar = tqdm(data, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)\n",
    "        running_loss = 0.0\n",
    "        num_batches = 0\n",
    "        for imgs, labels in progress_bar:\n",
    "            imgs = imgs.to(device)\n",
    "            labels = labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            out = model(imgs)\n",
    "            batch_loss = loss_func(out, labels)\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += batch_loss.item()\n",
    "            num_batches += 1\n",
    "        if num_batches == 0:\n",
    "            # Nothing was iterated; report it instead of failing on an undefined loss.\n",
    "            print(f'Epoch {epoch+1}/{num_epochs}: no batches in data')\n",
    "            continue\n",
    "        # The epoch mean is a steadier progress signal than the loss of the last batch alone.\n",
    "        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / num_batches}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "6eeec234-c8d2-494c-82cb-1cff787a826a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-class classification criterion; nn.CrossEntropyLoss takes raw (unnormalised)\n",
    "# logits as its input and applies log-softmax internally.\n",
    "loss = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "1f258d43-e0e3-4828-8a2f-cbb1bbc810b6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20, Loss: 4.339532375335693\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/20, Loss: 4.335596084594727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/20, Loss: 4.316740989685059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/20, Loss: 4.320252418518066\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/20, Loss: 4.270524024963379\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/20, Loss: 4.32129430770874\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/20, Loss: 4.303084850311279\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/20, Loss: 4.329886436462402\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/20, Loss: 4.329293727874756\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/20, Loss: 4.3292012214660645\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11/20, Loss: 4.330680847167969\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12/20, Loss: 4.330737113952637\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13/20, Loss: 4.330837249755859\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14/20, Loss: 4.329710483551025\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15/20, Loss: 4.330772399902344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16/20, Loss: 4.330124855041504\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17/20, Loss: 4.330763816833496\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18/20, Loss: 4.330377101898193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19/20, Loss: 4.330754280090332\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                                                       "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20/20, Loss: 4.3297295570373535\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "# Train `models` for 20 epochs on `dataloader` using the cross-entropy `loss`.\n",
    "train(dataloader, 20, loss, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "109185d7-b49a-4b22-b46b-eadf2fb2aae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the preprocessing pipeline on one test image.\n",
    "# `test_imgs[0]` is a numpy.ndarray, and torchvision's Resize only accepts a PIL Image or a\n",
    "# Tensor (this cell previously failed with `TypeError: Unexpected type <class 'numpy.ndarray'>`),\n",
    "# so convert the array to a PIL Image before handing it to `transform`.\n",
    "# NOTE(review): .convert('RGB') assumes the pipeline/model expect 3-channel input and that any\n",
    "# colour array is already in RGB order (cv2 loads BGR) - confirm against how test_imgs is built.\n",
    "test_pil = transforms.ToPILImage()(test_imgs[0]).convert('RGB')\n",
    "transform(test_pil)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "433aee70-82f0-44fb-8a67-f1bbce8e3301",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Single-image inference on the first test image.\n",
    "# The model must be called with an input batch; `models()` with no argument cannot run.\n",
    "# eval() switches Dropout/BatchNorm to inference behaviour. It leaves the model in eval mode,\n",
    "# so call models.train() again before any further training.\n",
    "models.eval()\n",
    "# Use whatever device the weights live on instead of hard-coding it.\n",
    "model_device = next(models.parameters()).device\n",
    "# Resize (inside `transform`) rejects numpy arrays, so go through a PIL Image first.\n",
    "# NOTE(review): .convert('RGB') assumes the model was trained on 3-channel RGB input - confirm.\n",
    "test_pil = transforms.ToPILImage()(test_imgs[0]).convert('RGB')\n",
    "# NOTE(review): assumes `transform` ends in ToTensor and returns a (C, H, W) tensor - confirm.\n",
    "# unsqueeze(0) adds the batch dimension the network expects.\n",
    "test_batch = transform(test_pil).unsqueeze(0).to(model_device)\n",
    "# No gradients are needed for a forward-only pass.\n",
    "with torch.no_grad():\n",
    "    test_pred = models(test_batch)\n",
    "test_pred"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
